ShuffleNet model
ToucheSir
left a comment
Thanks for the contribution! This is a great start, and the next steps would be adding tests + better matching the code style of the rest of the repo.
src/convnets/shufflenet.jl
| function ChannelShuffle(x::Array{Float32, 4}, g::Int) | ||
| width, height, channels, batch = size(x) | ||
| channels_per_group = channels÷g | ||
| if (channels % g) == 0 |
| if (channels % g) == 0 | |
| if channels % g == 0 |
We have a JuliaFormatter config in this repo, so make sure to run that before pushing your code.
src/convnets/shufflenet.jl
| - `channels`: number of channels | ||
| - `groups`: number of groups | ||
| """ | ||
| function ChannelShuffle(x::Array{Float32, 4}, g::Int) |
| function ChannelShuffle(x::Array{Float32, 4}, g::Int) | |
| function channel_shuffle(x::AbstractArray{Float32, 4}, g::Int) |
This type constraint is too restrictive. If ChannelShuffle works for all number types, then its signature should reflect that. Generally, all utility functions in Metalhead need to be GPU-compatible too. The renaming is a suggestion for how to make this function more "Julian", since it's not a callable type (which would be PascalCase) but a plain function. Lastly, how does this handle 3D inputs?
There was a problem hiding this comment.
I didn't think about it when writing the function. For 3D inputs, such as a batch of grey images, it would be necessary to artificially create a channel dimension.
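To make the points above concrete, here is a minimal sketch of a `channel_shuffle` that accepts any `AbstractArray` element type and uses only `reshape` and `permutedims` from Base, which is assumed (not verified here) to keep it usable with GPU array types. The reshape-and-permute body and the 3D method are illustrative assumptions, not the code from this PR; the 3D method simply adds the singleton channel dimension mentioned in the reply.

```julia
# Sketch only: generic over element type and array type, Base functions only.
function channel_shuffle(x::AbstractArray{T, 4}, g::Integer) where {T}
    width, height, channels, batch = size(x)
    channels % g == 0 ||
        throw(ArgumentError("channels ($channels) must be divisible by groups ($g)"))
    channels_per_group = channels ÷ g
    # Split the channel dimension into (channels_per_group, g),
    # swap those two dimensions, then merge them back into one.
    x = reshape(x, width, height, channels_per_group, g, batch)
    x = permutedims(x, (1, 2, 4, 3, 5))
    return reshape(x, width, height, channels, batch)
end

# Hypothetical 3D method: treat (width, height, batch) as a single-channel input.
function channel_shuffle(x::AbstractArray{T, 3}, g::Integer) where {T}
    width, height, batch = size(x)
    y = channel_shuffle(reshape(x, width, height, 1, batch), g)
    return reshape(y, width, height, batch)
end

x = reshape(collect(Float32, 1:6), 1, 1, 6, 1)  # six channels holding 1..6
y = channel_shuffle(x, 2)
vec(y)  # Float32[1, 4, 2, 5, 3, 6]
```

With a single channel only `g = 1` is valid, so the 3D method is mostly a convenience for keeping call sites uniform.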
src/convnets/shufflenet.jl
| BatchNorm(mid_channels), | ||
| NNlib.relu, |
| BatchNorm(mid_channels), | |
| NNlib.relu, | |
| BatchNorm(mid_channels, relu), |
relu is already in scope because of using NNlib, and fusing it into the preceding norm is slightly more efficient. Also, is the activation function not configurable for ShuffleNet?
src/convnets/shufflenet.jl
| m = Chain(Conv((1,1), in_channels => mid_channels; groups,pad=SamePad()), | ||
| BatchNorm(mid_channels), | ||
| NNlib.relu, | ||
| x -> ChannelShuffle(x, groups), |
| x -> ChannelShuffle(x, groups), | |
| Base.Fix2(channel_shuffle, groups), |
Will be easier on the compiler.
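As a rough illustration of the suggestion, `Base.Fix2(f, k)` builds a callable that behaves like `x -> f(x, k)` but is an ordinary parametric struct rather than a freshly generated closure type. The `rotate` function below is a toy stand-in for `channel_shuffle` and is purely hypothetical.

```julia
# Toy stand-in for channel_shuffle(x, groups): rotate a vector by k positions.
rotate(x, k) = circshift(x, k)

groups = 2
as_closure = x -> rotate(x, groups)   # anonymous function, as in the original code
as_fix2 = Base.Fix2(rotate, groups)   # fixes the second argument instead

as_closure([1, 2, 3, 4]) == as_fix2([1, 2, 3, 4])  # true

# The wrapped function and the fixed argument stay inspectable.
as_fix2.f === rotate
as_fix2.x == groups
as_fix2 isa Base.Fix2{typeof(rotate), Int}
```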
src/convnets/shufflenet.jl
| NNlib.relu) | ||
| if downsample | ||
| m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2)) |
| m = Parallel((mx, x) -> cat(mx, x, dims=3),m, MeanPool((3,3); pad=SamePad(), stride=2)) | |
| m = Parallel(cat_channels, m, MeanPool((3,3); pad=SamePad(), stride=2)) |
We have cat_channels for this exact case.
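For readers unfamiliar with the helper, the sketch below shows what a channel-wise concatenation does for arrays in width × height × channels × batch layout. `cat_channels_sketch` is a hypothetical stand-in written for this illustration; it is assumed to mirror the behaviour of the anonymous function being replaced and is not Metalhead's actual definition.

```julia
# Hypothetical stand-in, not Metalhead's definition:
# concatenate any number of arrays along the channel dimension (dims = 3).
cat_channels_sketch(xs::AbstractArray...) = cat(xs...; dims = 3)

a = ones(Float32, 4, 4, 3, 2)    # 3 channels
b = zeros(Float32, 4, 4, 5, 2)   # 5 channels
size(cat_channels_sketch(a, b))  # (4, 4, 8, 2)
```

Passing a named function like this to `Parallel` avoids defining a new anonymous function at every call site.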
src/convnets/shufflenet.jl
| model = Chain(features...) | ||
| return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) |
| return Chain(model, GlobalMeanPool(), Flux.flatten, Dense(in_channels => num_classes)) | |
| return Chain(model, GlobalMeanPool(), MLUtils.flatten, Dense(in_channels => num_classes)) |
The general modus operandi of this library has been to create named types for the top-level model and wrap the underlying Chain with them. You can see this pattern in the files for any of the other exported models.
For the suggestion: flatten is only imported into Flux, not defined there. It's preferable to use a symbol from the library that actually defines it when that library is available (which MLUtils is, being a dep of Metalhead).
There was a problem hiding this comment.
Could I see an example? Sorry, I'm still a newbie using Julia. I looked at the rest of convnets and tried to code in a similar style, but there are things I still haven't fully understood.
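A minimal sketch of the "named type wrapping the underlying model" pattern might look like the following. Everything here is an assumption for illustration: `ShuffleNetSketch` is a hypothetical name, the composed functions stand in for a Flux `Chain`, and the registration Flux needs to see the wrapper's parameters is omitted. The exported models in the repo remain the authoritative reference.

```julia
# Hypothetical wrapper type; in the real library `layers` would hold a Flux Chain.
struct ShuffleNetSketch{L}
    layers::L
end

# Calling the wrapper forwards to the wrapped layers.
(m::ShuffleNetSketch)(x) = m.layers(x)

# Stand-ins for the feature extractor and the classifier head.
backbone = x -> x .+ 1
classifier = x -> sum(x)

model = ShuffleNetSketch(classifier ∘ backbone)
model([1, 2, 3])  # 9
```

The benefit of the wrapper is that users dispatch on and print a meaningful model type instead of an anonymous `Chain`.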
Better matching the code style of the rest of Metalhead
I made the suggested changes

Thanks for the updates. On a quick skim nothing stands out to me, can you add it to the test suite to finish off the PR?
corrected typo
added missing includes
I'm working on this implementation of ShuffleNet from https://arxiv.org/abs/1707.01083.